第 3 课:数据加载(上)
整个数据处理过程有两个核心——Dataset, DataLoader。
Dataset用于数据读取,是一个抽象基类,提供给用户定义自己的数据读取方式DataLoader用于数据加载,有打乱数据,均衡1:1采样,多进程数据加载,组装成Batch形式等功能。
本章将围绕着它们两个展开介绍pytorch的数据读取、预处理、加载等功能。
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
整个数据加载的流程如下,我们会逐块讲解
+-------------+ +--------------+ +-------------+ +---------+
| Dataset | ---> | Data Loader | ---> | Batch Data | ---> | Model |
+-------------+ +--------------+ +-------------+ +---------+
(一)Dataset
torch.utils.data.Dataset 是数据加载和预处理的核心模块,用于定义和管理数据集。
通过 Dataset,我们可以方便地加载自定义数据集或使用现有的标准数据集。`
Dataset 是一个抽象类,表示数据集的接口。
你需要创建自己的 数据集类 并继承自 torch.utils.data.Dataset,然后实现以下两个核心函数:
__len__():返回数据集的大小。__getitem__(index):通过索引返回数据样本和标签,并进行预处理(包括online的数据增强)
此后 Dataloader 会调用Dataset的getitem函数, 由 返回值 组合成一个样本(batch)。
1. 案例一: 数据的标签在 txt 中
这里我们使用了一个 “利用X 光判断是否有新冠肺炎” 的例子
其文件格式如下
├─imgs
│ ├─covid-19 包含与 COVID-19 感染相关的医学图像。
│ │ auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg
│ │ ryct.2020200028.fig1a.jpeg
│ │
│ └─no-finding 包含无异常(正常情况)的医学图像。
│ 00001215_000.png
│ 00001215_001.png
│
└─labels 标签文件,每个文件的每一行对应一个图像的路径及其标签。
train.txt 训练集的标签文件
valid.txt 验证集的标签文件
class COVID19Dataset(Dataset):
def __init__(self, root_dir, txt_path, transform=None):
self.root_dir = root_dir # 图像文件的根目录
self.txt_path = txt_path # 标签文件的路径
self.transform = transform # 数据预处理方法(如调整尺寸、归一化等)
self.img_info = [] # 用于存储图像路径和标签,格式为 [(path, label), ...]
self.label_array = None # 暂未使用,但可以扩展为用于存储标签的张量
self._get_img_info() # 调用私有方法,加载图像路径和标签信息
def __getitem__(self, index): # 输入标量index, 从硬盘中读取数据,并预处理
path_img, label = self.img_info[index] # 根据索引 index 从 self.img_info 中提取图像路径和标签。
img = Image.open(path_img).convert('L') # 加载图像文件,并转换为灰度格式(convert('L'))。
if self.transform is not None: # 如果定义了预处理操作(transform),对图像应用预处理。
img = self.transform(img) # 应用预处理操作
return img, label
def __len__(self): # 返回数据集的样本数量,即 self.img_info 的长度。
if len(self.img_info) == 0:
raise Exception(f"\ndata_dir:{self.root_dir} is a empty dir! ")
return len(self.img_info)
def _get_img_info(self): # 将硬盘中的数据路径,标签读取进来,存在一个list中
with open(self.txt_path, "r") as f: # 读取txt,解析txt
txt_data = f.read().strip()
txt_data = txt_data.split("\n")
for i in txt_data:
self.img_info.append( (os.path.join(self.root_dir, i.split()[0]), int(i.split()[2])) )
# 元组的前者是 将图像的相对路径拼接为绝对路径,后者是 提取图像对应的标签
root_dir = "./data/cov_19_demo" # 数据的根目录
img_dir = os.path.join(root_dir, "imgs") # 图像所在的目录
path_txt_train = os.path.join(root_dir, "labels", "train.txt") # 训练数据集的标签
path_txt_valid = os.path.join(root_dir, "labels", "valid.txt") # 验证数据集的标签
transforms_func = transforms.Compose([
transforms.Resize((8, 8)), # 将图像调整为 8x8 像素大小。
transforms.ToTensor(), # 转换为张量并将像素值从 [0, 255] 归一化到 [0, 1]
])
train_data = COVID19Dataset(root_dir=img_dir, txt_path=path_txt_train, transform=transforms_func)
valid_data = COVID19Dataset(root_dir=img_dir, txt_path=path_txt_valid, transform=transforms_func)
# train_loader = DataLoader(dataset=train_data, batch_size=2)
# valid_loader = DataLoader(dataset=valid_data, batch_size=2)
上述代码运行流程如下:
- 主函数定义好 图像目录、测试验证集和验证数据集的目录
- 定义要对每一个样本进行的预处理操作
- 传入参数(地址和预处理操作),创建训练数据集和 测试数据集
- 调用init方法,完成初始化,然后 自动调用_get_img_info() 方法,读取所有标签并 存在属性self.img_info中
- 代码结束后,两个数据集的 img_info 属性中应当存了 一个包含样本地址和标签的表格
注意:这里 __getitem__ 和 __len__ 方法还没有被调用。
__len__ 用于返回数据集长度, __getitem__ 用于通过索引访问单个样本
触发方法如下:
# 触发 __getitem__
img, label = train_data[0]
print(f"第一个样本的图像大小: {img.shape}, 标签: {label}")
# 触发 __len__
print(f"训练集样本数量: {len(train_data)}")
# 第一个样本可视化
import matplotlib.pyplot as plt
plt.imshow(img.squeeze(), cmap="gray")
plt.title(f"Label: {label}")
plt.show()
第一个样本的图像大小: torch.Size([1, 8, 8]), 标签: 1
训练集样本数量: 2

结合 DataLoader 调用
如果你将 train_data 传入 DataLoader 并迭代数据,这两个方法会被自动调用:
from torch.utils.data import DataLoader
# 创建数据加载器
train_loader = DataLoader(train_data, batch_size=4, shuffle=True)
# 迭代加载数据
for images, labels in train_loader:
print(f"批次图像大小: {images.size()}, 标签: {labels}")
break # 打印一批后停止
批次图像大小: torch.Size([2, 1, 8, 8]), 标签: tensor([0, 1])
2. 案例二:数据的划分及标签在文件夹中
文件结构如下
.
├── train
│ ├── covid-19
│ │ └── ryct.2020200028.fig1a.jpeg
│ └── no-finding
│ └── 00001215_001.png
└── valid
├── covid-19
│ └── auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg
└── no-finding
└── 00001215_000.png
可以看到,没有专门的标签文件,而是直接把图片分好类放在了各个文件夹中
class COVID19Dataset_2(Dataset):
def __init__(self, root_dir, transform=None): #获取数据集的路径、预处理的方法
self.root_dir = root_dir
self.transform = transform
self.img_info = [] # [(path, label), ... , ]
self.label_array = None
self.str_2_int = {"no-finding": 0, "covid-19": 1} # 由于标签信息是string,需要转换为int类型
self._get_img_info()
def __getitem__(self, index): #输入标量index, 从硬盘中读取数据,并预处理,to Tensor
path_img, label = self.img_info[index]
img = Image.open(path_img).convert('L')
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
if len(self.img_info) == 0:
raise Exception(f"\ndata_dir:{self.root_dir} is a empty dir! ")
return len(self.img_info)
def _get_img_info(self):
for root, dirs, files in os.walk(self.root_dir): # 遍历根目录下的所有子目录和文件
# os.walk() 递归遍历所有子目录及其中的文件,生成目录树
# root 是当前遍历的目录路径,dir是当前目录下的子目录名称列表(不包含路径,只是名称),file类似
print(f"当前扫描路径:{root},该路径下子文件夹:{dirs},子文件{files}")
for file in files:
if file.endswith("png") or file.endswith("jpeg"): # 仅处理 PNG 和 JPEG 格式的文件
path_img = os.path.join(root, file) # 拼接完整图像路径
sub_dir = os.path.basename(root) # 获取文件所在的子目录名称
label_int = self.str_2_int[sub_dir] # 将子目录名称映射为整数标签
self.img_info.append((path_img, label_int)) # 将路径和标签存入 img_info
数据集实例化
# 定义训练集和验证集的根目录
train_root_dir = "./data/covid-19-dataset-2/train"
valid_root_dir = "./data/covid-19-dataset-2/valid"
transforms_func = transforms.Compose([
transforms.Resize((8, 8)), # 调整图像大小为 8x8
transforms.ToTensor(), # 转换为张量并归一化到 [0, 1]
])
# 实例化训练集和验证集
train_dataset = COVID19Dataset_2(root_dir=train_root_dir, transform=transforms_func)
valid_dataset = COVID19Dataset_2(root_dir=valid_root_dir, transform=transforms_func)
当前扫描路径:./data/covid-19-dataset-2/train,该路径下子文件夹:['no-finding', 'covid-19'],子文件[] 当前扫描路径:./data/covid-19-dataset-2/train/no-finding,该路径下子文件夹:[],子文件['00001215_001.png'] 当前扫描路径:./data/covid-19-dataset-2/train/covid-19,该路径下子文件夹:[],子文件['ryct.2020200028.fig1a.jpeg'] 当前扫描路径:./data/covid-19-dataset-2/valid,该路径下子文件夹:['no-finding', 'covid-19'],子文件[] 当前扫描路径:./data/covid-19-dataset-2/valid/no-finding,该路径下子文件夹:[],子文件['00001215_000.png'] 当前扫描路径:./data/covid-19-dataset-2/valid/covid-19,该路径下子文件夹:[],子文件['auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg']
测试
print(f"数据集大小: {len(train_dataset)}") # 调用 __len__()
img, label = train_dataset[0] # 调用 __getitem__()
print(f"图像大小: {img.shape}, 标签: {label}")
# from torch.utils.data import DataLoader
# train_loader = DataLoader(train_data, batch_size=4, shuffle=True)
# for images, labels in train_loader:
# print(f"批次图像大小: {images.size()}, 标签: {labels}")
# break
数据集大小: 2
图像大小: torch.Size([1, 8, 8]), 标签: 0
3. 案例三:数据的划分及标签在csv中
文件格式如下
.
├── dataset-meta-data.csv
└── imgs
├── 00001215_000.png
├── 00001215_001.png
├── auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg
└── ryct.2020200028.fig1a.jpeg
csv 文件中记录的内容格式为:
img-name,label,set-type,etc
ryct.2020200028.fig1a.jpeg,1,train,
class COVID19Dataset_3(Dataset):
def __init__(self, root_dir, path_csv, mode, transform=None): # 获取数据集的路径、预处理的方法。
# 由于数据划分体现在同一份文件中,因此需要设计 train/valid模式
self.root_dir = root_dir
self.path_csv = path_csv
self.mode = mode
self.transform = transform
self.img_info = [] # [(path, label), ... , ]
self.label_array = None
# 由于标签信息是string,需要一个字典转换为模型训练时用的int类型
self._get_img_info()
def __getitem__(self, index):
path_img, label = self.img_info[index]
img = Image.open(path_img).convert('L')
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
if len(self.img_info) == 0:
raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(
self.root_dir))
return len(self.img_info)
def _get_img_info(self):
df = pd.read_csv(self.path_csv) # 加载 CSV 文件
df.drop(df[df["set-type"] != self.mode].index, inplace=True) # 筛选指定模式的数据集 (如 train 或 valid)
df.reset_index(inplace=True) # 重置索引,避免删除行后索引不连续
for idx in range(len(df)): # 遍历筛选后的 DataFrame
path_img = os.path.join(self.root_dir, df.loc[idx, "img-name"]) # 获取图像路径
# print(idx)
# print(path_img)
label_int = int(df.loc[idx, "label"]) # 获取标签并转换为整数
self.img_info.append((path_img, label_int)) # 添加到 img_info 列表中
方法解释
df.drop():删除与当前模式(mode)不匹配的数据行。
例如,如果当前模式是 train,则保留 set-type 为 train 的数据行。df.reset_index():删除行后,原索引可能不连续,reset_index 用于重置索引,避免后续访问时出错。os.path.join(self.root_dir, df.loc[idx, "img-name"]):
将根目录路径 root_dir 和图像文件名 img-name 拼接成完整路径。int(df.loc[idx, "label"]):将标签转换为对应的数字
实例化
# 数据集根目录
root_dir = "./data/covid-19-dataset-3/imgs"
# CSV 文件路径
path_csv = "./data/covid-19-dataset-3/dataset-meta-data.csv"
# 数据预处理操作
transforms_func = transforms.Compose([
transforms.Resize((8, 8)), # 调整大小为 8x8
transforms.ToTensor(), # 转换为张量并归一化到 [0, 1]
])
# 实例化训练集和验证集
train_data = COVID19Dataset_3(root_dir=root_dir, path_csv=path_csv, mode="train", transform=transforms_func)
valid_data = COVID19Dataset_3(root_dir=root_dir, path_csv=path_csv, mode="valid", transform=transforms_func)
# 查看数据集大小
print(f"训练集大小: {len(train_data)}")
print(f"验证集大小: {len(valid_data)}")
# 获取一个样本
img, label = train_data[0]
print(f"第一个样本的图像大小: {img.shape}, 标签: {label}")
训练集 大小: 2
验证集大小: 2
第一个样本的图像大小: torch.Size([1, 8, 8]), 标签: 1
4. 案例四:使用标准数据集
PyTorch 提供了一些常见的数据集,例如 MNIST、CIFAR-10 等,
这些数据集已经实现了 Dataset 的接口,可以直接使用。
MNIST 是一个经典的手写数字数据集
- 类别数:10 个类别,对应数字
0-9。 - 样本数量:
- 训练集:60,000 张图片。
- 测试集:10,000 张图片。
- 图像尺寸:28x28 像素。
- 通道数:单通道灰度图(只有一个颜色通道)。
- 特点:
- 图片已经被中心化,去除了背景噪声。
- 图片尺寸较小(28x28),适合快速训练和实验。
from torchvision import datasets, transforms
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # 标准化到 [-1, 1]
])
# 加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 查看数据集长度
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
# 获取一个样本
image, label = train_dataset[0]
print(f"第一个样本的标签: {label}")
训练集大小: 60000
测试集大小: 10000
第一个样本的标签: 5
(二)DataLoader
Dataset 是用于定义数据集的类(也就是如何读取数据),而 DataLoader 则用于加载、处理、组装、采样数据。
DataLoader 的主要功能:
-
支持两种类型的数据集读取方式
映射式 (Map-style):通过索引直接获取某条数据。比如之前用的Dataset通过getitem方法用序号获取数据。迭代式 (Iterable-style):这种方式是一次次从数据源中顺序取出数据,就像迭代一个列表。不过这种方式这里暂时不展开。
-
可以自定义数据采样规则
- 使用
Sampler,我们可以为每种类别分配不同的采样权重,比如当某些类别数据少时,可以调整权重让这些类别的样本多被采样,达到平衡采样的效果。 - 还可以完全自定义自己的采样逻辑。
- 使用
-
自动将数据整理成批次
- 数据会被自动分成小批次(mini-batch)以便训练,这样可以直接用在模型里进行批量计算。
- 如果有特殊需求,还可以通过参数
batch_sampler或collate_fn自定义批次数据的组合方式。
-
支持多进程并行加载数据
- 通常情况下,GPU 处理数据的速度比 CPU 加载数据的速度快,这会导致 GPU 等待数据加载,影响效率。
- 为了 缓解这个问题,可以设置多进程来并行加载数据。比如,如果你的 CPU 有 16 个核心,可以设置
num_workers=16来充分利用 CPU 资源,提高数据加载速度。
标准格式如下:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
参数解释:
dataset:一个Dataset实例, 确保要有getitem函数batch_size:每个batch的样本量shuffle:是否对打乱样本顺序。训练集通常要打乱它!验证集和测试集无所谓。sampler:设置采样策略。batch_sampler:设置采样策略, batch_sampler与sampler二选一num_workers: 设置多少个子进程进行数据加载,需要和CPU的核数一样collate_fn:组装数据的规则, 决定如何将一批数据组装起来。pin_memory:是否使用锁页内存,drop_last:每个epoch是否放弃最后一批不足batchsize大小的数据,
不代表这些剩下的数据被舍弃了,因为每个epoch,dataloader的采样都会重新shuffle打乱
import os
cpu_count = os.cpu_count()
print(f"CPU 核数:{cpu_count}")
CPU 核数:12
我们这里采用了上面Dataset里的 案例二
创建 Dataloader
# 定义训练集和验证集的 DataLoader
train_loader = DataLoader(
dataset=train_dataset,
batch_size=4, # 每批加载 4 个样本
shuffle=True, # 打乱训练数据
num_workers=12
)
valid_loader = DataLoader(
dataset=valid_dataset,
batch_size=4,
shuffle=False, # 验证数据无需打乱
num_workers=12
)
# print(train_loader)
转到GPU上
PyTorch 默认情况下会在 CPU 上运行,除非明确指定使用 GPU。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"运行设备: {device}")
运行设备: cuda
遍历训练集和验证集
# 遍历训练集
print("训练集样本:")
for batch_idx, (images, labels) in enumerate(train_loader):
images, labels = images.to(device), labels.to(device) # 将数据转移到 GPU
print(f"批次 {batch_idx + 1}:")
print(f" - 图像大小: {images.shape}")
print(f" - 标签: {labels}")
break
# 遍历验证集
print("\n验证集样本:")
for batch_idx, (images, labels) in enumerate(valid_loader):
images, labels = images.to(device), labels.to(device) # 将数据转移到 GPU
print(f"批次 {batch_idx + 1}:")
print(f" - 图像大小: {images.shape}")
print(f" - 标签: {labels}")
break
训练集样本:
批次 1:
- 图像大小: torch.Size([2, 1, 8, 8])
- 标签: tensor([1, 0], device='cuda:0')
验证集样本:
批次 1:
- 图像大小: torch.Size([2, 1, 8, 8])
- 标签: tensor([0, 1], device='cuda:0')
注意:
在 Windows 上,需要将代码封装在 if name == "main": 中,否则会导致多线程加载出错。